{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# An Introduction to Factorization Machines with MNIST\n",
    "_**Making a Binary Prediction of Whether a Handwritten Digit is a 0**_\n",
    "\n",
    "1. [Introduction](#Introduction)\n",
    "2. [Prerequisites and Preprocessing](#Prerequisites-and-Preprocessing)\n",
    "  1. [Permissions and environment variables](#Permissions-and-environment-variables)\n",
    "  2. [Data ingestion](#Data-ingestion)\n",
    "  3. [Data inspection](#Data-inspection)\n",
    "  4. [Data conversion](#Data-conversion)\n",
    "3. [Training the factorization machine model](#Training-the-factorization-machine-model)\n",
    "4. [Set up hosting for the model](#Set-up-hosting-for-the-model)\n",
    "  1. [Import model into hosting](#Import-model-into-hosting)\n",
    "  2. [Create endpoint configuration](#Create-endpoint-configuration)\n",
    "  3. [Create endpoint](#Create-endpoint)\n",
    "5. [Validate the model for use](#Validate-the-model-for-use)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Introduction\n",
    "\n",
    "Welcome to our example introducing Amazon SageMaker's Factorization Machines Algorithm!  Today, we're analyzing the [MNIST](https://en.wikipedia.org/wiki/MNIST_database) dataset which consists of images of handwritten digits, from zero to nine.  We'll use the individual pixel values from each 28 x 28 grayscale image to predict a yes or no label of whether the digit is a 0 or some other digit (1, 2, 3, ... 9).\n",
    "\n",
    "The method that we'll use is a factorization machine binary classifier.  A factorization machine is a general-purpose supervised learning algorithm that you can use for both classification and regression tasks.  It is an extension of a linear model that is designed to parsimoniously capture interactions between features in high dimensional sparse datasets.  For example, in a click prediction system, the factorization machine model can capture click rate patterns observed when ads from a certain ad-category are placed on pages from a certain page-category.  Factorization machines are a good choice for tasks dealing with high dimensional sparse datasets, such as click prediction and item recommendation.\n",
    "\n",
    "Amazon SageMaker's Factorization Machine algorithm provides a robust, highly scalable implementation of this algorithm, which has become extremely popular in ad click prediction and recommender systems.  The main purpose of this notebook is to quickly show the basics of implementing Amazon SageMaker Factorization Machines, even if the use case of predicting a digit from an image is not where factorization machines shine.\n",
    "\n",
    "To get started, we need to set up the environment with a few prerequisite steps, for permissions, configurations, and so on."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prerequisites and Preprocessing\n",
    "\n",
    "### Permissions and environment variables\n",
    "\n",
    "_This notebook was created and tested on an ml.m4.xlarge notebook instance._\n",
    "\n",
    "Let's start by specifying:\n",
    "\n",
    "- The S3 bucket and prefix that you want to use for training and model data.  This should be within the same region as the Notebook Instance, training, and hosting.\n",
    "- The IAM role ARN used to give training and hosting access to your data. See the documentation for how to create these.  Note, if more than one role is required for notebook instances, training, and/or hosting, please replace the `get_execution_role()` call with the appropriate full IAM role ARN string(s)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "isConfigCell": true,
    "tags": [
     "parameters"
    ]
   },
   "outputs": [],
   "source": [
    "import boto3\n",
    "import sagemaker\n",
    "from sagemaker import get_execution_role\n",
    "\n",
    "# Region of the current boto3 session\n",
    "session = boto3.session.Session()\n",
    "region_name = session.region_name\n",
    "\n",
    "# Default SageMaker bucket for this account and region, plus a key prefix\n",
    "sagemaker_session = sagemaker.Session()\n",
    "bucket = sagemaker_session.default_bucket()\n",
    "prefix = 'sagemaker/factorization-machine-mnist'\n",
    "\n",
    "# IAM role that training and hosting use to access the data\n",
    "role = get_execution_role()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data ingestion\n",
    "\n",
    "Next, we read the dataset from an online URL into memory, for preprocessing prior to training. This processing could be done *in situ* by Amazon Athena, Apache Spark in Amazon EMR, Amazon Redshift, etc., assuming the dataset is present in the appropriate location. Then, the next step would be to transfer the data to S3 for use in training. For small datasets, such as this one, reading into memory isn't onerous, though it would be for larger datasets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "CPU times: user 628 ms, sys: 319 ms, total: 946 ms\n",
      "Wall time: 2.16 s\n"
     ]
    }
   ],
   "source": [
    "%%time\n",
    "import pickle, gzip, numpy, urllib.request, json\n",
    "\n",
    "# Load the dataset\n",
    "urllib.request.urlretrieve(\"http://deeplearning.net/data/mnist/mnist.pkl.gz\", \"mnist.pkl.gz\")\n",
    "with gzip.open('mnist.pkl.gz', 'rb') as f:\n",
    "    train_set, valid_set, test_set = pickle.load(f, encoding='latin1')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data inspection\n",
    "\n",
    "Once the dataset is imported, it's typical as part of the machine learning process to inspect the data, understand the distributions, and determine what type(s) of preprocessing might be needed. You can perform those tasks right here in the notebook. As an example, let's go ahead and look at one of the digits that is part of the dataset."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAJEAAACfCAYAAAD50jtTAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDMuMC4zLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvnQurowAAB4pJREFUeJzt3V1oVvcdB/DvtzotumjIuimTOYmlERUX6JKAOKxzzkl11JeLZkgvdNKtCvZGxrwYuqFYfAGDwrwS3Fbr2I0vjClDoy1SafBlW6WyljkWsV2Hb0lMdEl+u8hTyO+QPC/7PSdP8jzfz435Js855x/y5Z+/J885h2YGkYhnSj0AGftUIglTiSRMJZIwlUjCVCIJq7gSkdxB8rdZvv4hyZcK3Od3SN4KD26MGl/qARQbyc5BcRKAJwD6Mvn1XNub2bxCj2lm7wKoK3S7QpCcBeAfALoGffotM/tVmsfNR9mVyMy+/MXHJG8D+LGZ/XnQ53aUYFjFVG1mvaUexGAV9+ssYwLJYyQ7Mr++vv3FF0jeJvm9zMeNJNtIPiL5GckDQ+2M5Esk2wfln5G8k9n/LZJLh9nuZZLXMvv/11gteKWW6IcA3gFQDeAUgEPDvO4ggINmNgXAbAC/z7VjknUAtgBoMLMqAMsB3B7m5V0AXsuM42UAPyX5So5D/JNkO8mjJJ/LNZ6RUKkles/M/mhmfQB+A+Bbw7zuvwCeJ/mcmXWa2ft57LsPwEQAc0l+ycxum9knQ73QzFrN7K9m1m9mfwFwHMDiYfb7HwANAL4J4EUAVQB+l8d4UlepJfp00MePATxLcqj14UYALwD4iOQHJFfm2rGZfQzgTQA7APyb5Dskvz7Ua0k2kbxA8nOSDwH8BMCQs0umxG1m1mtmn2Fgtvs+yapcY0pbpZYoL2b2dzNrBvA1AG8B+APJyXls97aZLcLArGGZbYfyNgZ+nX7DzKYC+DUA5ju8zL8l/xmWfACjGcn1JL9qZv0AHmQ+3Z9jmzqS3yU5EUAPgO4s21QBuGdmPSQbAfwoy36bMvt+huRXALQAaDWzh4V+X8WmEmX3AwAfZs49HQTwqpl159hmIoA9GFjDfIqBWeznw7z2DQC/JNkB4BfIvnCvBfAnAB0A/oaB81/NeX4fqaLelCZRmokkTCWSMJVIwlQiCVOJJGxE/4pPUv8VHMPMbMgToZqJJEwlkjCVSMJUIglTiSRMJZIwlUjCVCIJU4kkTCWSMJVIwlQiCVOJJEwlkjCVSMJUIglTiSRMJZIwlUjCyu5OaWkaN26cy1OnTi1o+y1btrg8adIkl+vq/B37Nm/e7PK+fftcbm72V1H39PS4vGfPHpd37tyZ/2ALoJlIwlQiCVOJJKyi1kQzZ850ecKECS4vXLjQ5UWLFrlcXV3t8tq1a4s4OqC9vd3llpYWl1evXu1yR0eHyzdu3HD54sWLRRzd8DQTSZhKJGEqkYSN6J3SRvpa/Pr6epfPnz/vcqHneYqtv9/fynHDhg0ud3Z2Ipu7d++6fP/+fZdv3Sru40Z0Lb6kRiWSMJVIwsp6TVRTU+PylStXXK6trS3q8ZL7f/DggctLlixx+enTpy6Xeo2Wi9ZEkhqVSMJUIgkr67+d3bt3z+Vt27a5vHKlf2jQtWvXXE7+7Srp+vXrLi9btszlrq4ul+fN808G3bp1a9b9jxWaiSRMJZIwlUjCyvo8US5TpkxxOfn+nCNHjri8ceNGl9evX+/y8ePHizi60UfniSQ1KpGEqUQSVtbniXJ59OhR1q8/fJj98aqbNm1y+cSJEy4n3y9UrjQTSZhKJGEqkYRV9HmiXCZPnuzy6dOnXV68eLHLK1ascPncuXPpDKxEdJ5IUqMSSZhKJGFaExVg9uzZLl+9etXl5HuqL1y44HJbW5vLhw8fdnkkfxb/D62JJDUqkYSpRBKmNVFA8n5BR48edbmqqirr9tu3
b3f52LFjLievtS81rYkkNSqRhKlEEqY1URHNnz/f5QMHDri8dOnSrNsn39O9a9cul+/cuRMYXZzWRJIalUjCVCIJ05ooRcn7Xq9atcrl5Hkl0i85kveYTF7rP9K0JpLUqEQSphJJmNZEJfTkyROXx4/3lwH29va6vHz5cpdbW1tTGddwtCaS1KhEEqYSSVhFX4tfbAsWLHB53bp1Ljc0NLicXAMl3bx50+VLly4FRpcezUQSphJJmEokYVoTFSD53Prkc+7XrFnj8vTp0wvaf19fn8vJ91iP1vsdaSaSMJVIwlQiCdOaaJDkGqa5udnl5Bpo1qxZoeMlr81Pvqf61KlTof2PFM1EEqYSSZhKJGEVtSaaNm2ay3PnznX50KFDLs+ZMyd0vOQzYffu3evyyZMnXR6t54Fy0UwkYSqRhKlEElZWa6KamhqXk9e219fXu1xbWxs63uXLl13ev3+/y2fPnnW5u7s7dLzRSjORhKlEEqYSSdiYWhM1NTW5nHzOfWNjo8szZswIHe/x48cut7S0uLx7926Xu7q6QscbqzQTSZhKJGEqkYSNqTVR8r7RyZxL8jquM2fOuJy89j153if57A4ZoJlIwlQiCVOJJEz3J5K86f5EkhqVSMJUIglTiSRMJZIwlUjCVCIJU4kkTCWSMJVIwlQiCRvRv51JedJMJGEqkYSpRBKmEkmYSiRhKpGEqUQSphJJmEokYSqRhKlEEqYSSZhKJGEqkYSpRBKmEkmYSiRhKpGEqUQSphJJmEokYSqRhKlEEvY/fSsBuuo1S20AAAAASUVORK5CYII=\n",
      "text/plain": [
       "<Figure size 144x720 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "plt.rcParams[\"figure.figsize\"] = (2,10)\n",
    "\n",
    "\n",
    "def show_digit(img, caption='', subplot=None):\n",
    "    # Create a new axes only when the caller did not supply one\n",
    "    if subplot is None:\n",
    "        _, subplot = plt.subplots(1, 1)\n",
    "    imgr = img.reshape((28, 28))  # flat 784-element vector -> 28 x 28 image\n",
    "    subplot.axis('off')\n",
    "    subplot.imshow(imgr, cmap='gray')\n",
    "    subplot.set_title(caption)  # title the axes that was drawn on\n",
    "\n",
    "show_digit(train_set[0][0], 'This is a {}'.format(train_set[1][0]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data conversion\n",
    "\n",
    "Since algorithms have particular input and output requirements, converting the dataset is also part of the process that a data scientist goes through prior to initiating training. In this particular case, the Amazon SageMaker implementation of Factorization Machines takes recordIO-wrapped protobuf, whereas the data we have today is a pickled NumPy array on disk.\n",
    "\n",
    "Most of the conversion effort is handled by the Amazon SageMaker Python SDK, imported as `sagemaker` below.\n",
    "\n",
    "_Notice that, although most use cases for factorization machines will utilize sparse input, we are writing our data out as dense tensors.  This will be fine since the MNIST dataset is not particularly large or high dimensional._"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.01171875, 0.0703125 , 0.0703125 ,\n",
       "       0.0703125 , 0.4921875 , 0.53125   , 0.68359375, 0.1015625 ,\n",
       "       0.6484375 , 0.99609375, 0.96484375, 0.49609375, 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.1171875 , 0.140625  , 0.3671875 , 0.6015625 ,\n",
       "       0.6640625 , 0.98828125, 0.98828125, 0.98828125, 0.98828125,\n",
       "       0.98828125, 0.87890625, 0.671875  , 0.98828125, 0.9453125 ,\n",
       "       0.76171875, 0.25      , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.19140625, 0.9296875 ,\n",
       "       0.98828125, 0.98828125, 0.98828125, 0.98828125, 0.98828125,\n",
       "       0.98828125, 0.98828125, 0.98828125, 0.98046875, 0.36328125,\n",
       "       0.3203125 , 0.3203125 , 0.21875   , 0.15234375, 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.0703125 , 0.85546875, 0.98828125, 0.98828125,\n",
       "       0.98828125, 0.98828125, 0.98828125, 0.7734375 , 0.7109375 ,\n",
       "       0.96484375, 0.94140625, 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.3125    , 0.609375  , 0.41796875, 0.98828125, 0.98828125,\n",
       "       0.80078125, 0.04296875, 0.        , 0.16796875, 0.6015625 ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.0546875 ,\n",
       "       0.00390625, 0.6015625 , 0.98828125, 0.3515625 , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.54296875,\n",
       "       0.98828125, 0.7421875 , 0.0078125 , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.04296875, 0.7421875 , 0.98828125,\n",
       "       0.2734375 , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.13671875, 0.94140625, 0.87890625, 0.625     ,\n",
       "       0.421875  , 0.00390625, 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.31640625, 0.9375    , 0.98828125, 0.98828125, 0.46484375,\n",
       "       0.09765625, 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.17578125,\n",
       "       0.7265625 , 0.98828125, 0.98828125, 0.5859375 , 0.10546875,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.0625    , 0.36328125,\n",
       "       0.984375  , 0.98828125, 0.73046875, 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.97265625, 0.98828125,\n",
       "       0.97265625, 0.25      , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.1796875 , 0.5078125 ,\n",
       "       0.71484375, 0.98828125, 0.98828125, 0.80859375, 0.0078125 ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.15234375,\n",
       "       0.578125  , 0.89453125, 0.98828125, 0.98828125, 0.98828125,\n",
       "       0.9765625 , 0.7109375 , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.09375   , 0.4453125 , 0.86328125, 0.98828125, 0.98828125,\n",
       "       0.98828125, 0.98828125, 0.78515625, 0.3046875 , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.08984375, 0.2578125 , 0.83203125, 0.98828125,\n",
       "       0.98828125, 0.98828125, 0.98828125, 0.7734375 , 0.31640625,\n",
       "       0.0078125 , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.0703125 , 0.66796875, 0.85546875,\n",
       "       0.98828125, 0.98828125, 0.98828125, 0.98828125, 0.76171875,\n",
       "       0.3125    , 0.03515625, 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.21484375, 0.671875  ,\n",
       "       0.8828125 , 0.98828125, 0.98828125, 0.98828125, 0.98828125,\n",
       "       0.953125  , 0.51953125, 0.04296875, 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.53125   , 0.98828125, 0.98828125, 0.98828125,\n",
       "       0.828125  , 0.52734375, 0.515625  , 0.0625    , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        , 0.        ,\n",
       "       0.        , 0.        , 0.        , 0.        ], dtype=float32)"
      ]
     },
     "execution_count": 12,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import io\n",
    "import numpy as np\n",
    "\n",
    "# Pixel features as a float32 matrix, one 784-element row per image\n",
    "vectors = np.asarray(train_set[0], dtype='float32')\n",
    "vectors[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([5, 0, 4, ..., 8, 4, 8])"
      ]
     },
     "execution_count": 30,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "np.array([t.tolist() for t in train_set[1]])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "50000"
      ]
     },
     "execution_count": 46,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(np.where(np.array([t.tolist() for t in train_set[1]]) == 0, 1.0, 0.0))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "array([0., 1., 0., ..., 0., 0., 0.], dtype=float32)"
      ]
     },
     "execution_count": 47,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Binary target: 1.0 where the digit is a 0, and 0.0 for every other digit\n",
    "labels = np.where(np.asarray(train_set[1]) == 0, 1.0, 0.0).astype('float32')\n",
    "labels"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
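   ],
   "source": [
    "import sagemaker.amazon.common as smac\n",
    "\n",
    "# Serialize features and labels as recordIO-wrapped protobuf into an in-memory buffer\n",
    "buf = io.BytesIO()\n",
    "smac.write_numpy_to_dense_tensor(buf, vectors, labels)\n",
    "# Rewind so the upload step reads the buffer from the beginning\n",
    "buf.seek(0)"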
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Upload training data\n",
    "Now that we've created our recordIO-wrapped protobuf, we'll need to upload it to S3, so that Amazon SageMaker training can use it."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    },
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "import boto3\n",
    "\n",
    "key = 'recordio-pb-data'\n",
    "# S3 object keys always use forward slashes, so build the key explicitly\n",
    "s3_key = '{}/train/{}'.format(prefix, key)\n",
    "boto3.resource('s3').Bucket(bucket).Object(s3_key).upload_fileobj(buf)\n",
    "s3_train_data = 's3://{}/{}'.format(bucket, s3_key)\n",
    "print('uploaded training data location: {}'.format(s3_train_data))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Let's also set up an S3 output location for the model artifact that training produces."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "output_location = 's3://{}/{}/output'.format(bucket, prefix)\n",
    "print('training artifacts will be uploaded to: {}'.format(output_location))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Training the factorization machine model\n",
    "\n",
    "Once we have the data preprocessed and available in the correct format for training, the next step is to actually train the model using the data. Since this dataset is relatively small, it isn't meant to show off the performance of Amazon SageMaker's Factorization Machines in training, although we have tested it on multi-terabyte datasets.\n",
    "\n",
    "Again, we'll use the Amazon SageMaker Python SDK to kick off training and monitor status until it is completed.  In this example that takes between 7 and 11 minutes.  Despite the dataset being small, provisioning hardware and loading the algorithm container take time upfront.\n",
    "\n",
    "First, let's specify our container.  So that this notebook runs in any region where the algorithm is available, we use the SDK's `get_image_uri` helper to look up the Factorization Machines image for the current region rather than hard-coding it.  More details on algorithm containers can be found in [AWS documentation](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-algo-docker-registry-paths.html)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "from sagemaker.amazon.amazon_estimator import get_image_uri\n",
    "container = get_image_uri(boto3.Session().region_name, 'factorization-machines')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Next we'll kick off the base estimator, making sure to pass in the necessary hyperparameters.  Notice:\n",
    "- `feature_dim` is set to 784, which is the number of pixels in each 28 x 28 image.\n",
    "- `predictor_type` is set to 'binary_classifier' since we are trying to predict whether the image is or is not a 0.\n",
    "- `mini_batch_size` is set to 200.  This value can be tuned for relatively minor improvements in fit and speed, but selecting a reasonable value relative to the dataset is appropriate in most cases.\n",
    "- `num_factors` is set to 10.  As mentioned initially, factorization machines find a lower dimensional representation of the interactions for all features.  Making this value smaller provides a more parsimonious model, closer to a linear model, but may sacrifice information about interactions.  Making it larger provides a higher-dimensional representation of feature interactions, but adds computational complexity and can lead to overfitting.  In a practical application, time should be invested to tune this parameter to the appropriate value."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "import boto3\n",
    "import sagemaker\n",
    "\n",
    "sess = sagemaker.Session()\n",
    "\n",
    "fm = sagemaker.estimator.Estimator(container,\n",
    "                                   role,\n",
    "                                   train_instance_count=1,\n",
    "                                   train_instance_type='ml.c4.xlarge',\n",
    "                                   output_path=output_location,\n",
    "                                   sagemaker_session=sess)\n",
    "fm.set_hyperparameters(feature_dim=784,\n",
    "                       predictor_type='binary_classifier',\n",
    "                       mini_batch_size=200,\n",
    "                       num_factors=10)\n",
    "\n",
    "fm.fit({'train': s3_train_data})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Set up hosting for the model\n",
    "Now that we've trained our model, we can deploy it behind an Amazon SageMaker real-time hosted endpoint.  This will allow us to make predictions (or inferences) from the model dynamically.\n",
    "\n",
    "_Note, Amazon SageMaker allows you the flexibility of importing models trained elsewhere, as well as the choice of not importing models if the target of model creation is AWS Lambda, AWS Greengrass, Amazon Redshift, Amazon Athena, or other deployment target._"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "fm_predictor = fm.deploy(initial_instance_count=1,\n",
    "                         instance_type='ml.m4.xlarge')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Validate the model for use\n",
    "Finally, we can now validate the model for use.  We can pass HTTP POST requests to the endpoint to get back predictions.  To make this easier, we'll again use the Amazon SageMaker Python SDK and specify how to serialize requests and deserialize responses that are specific to the algorithm.\n",
    "\n",
    "Since factorization machines are so frequently used with sparse data, making inference requests with a CSV format (as is done in other algorithm examples) can be massively inefficient.  Rather than waste space and time generating all of those zeros, to pad the row to the correct dimensionality, JSON can be used more efficiently.  Since we trained the model using dense data, this is a bit of a moot point, as we'll have to pass all the 0s in anyway.\n",
    "\n",
    "Nevertheless, we'll write our own small function to serialize our inference request in the JSON format that Amazon SageMaker Factorization Machines expects."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "import json\n",
    "from sagemaker.predictor import json_deserializer\n",
    "\n",
    "def fm_serializer(data):\n",
    "    js = {'instances': []}\n",
    "    for row in data:\n",
    "        js['instances'].append({'features': row.tolist()})\n",
    "    return json.dumps(js)\n",
    "\n",
    "fm_predictor.content_type = 'application/json'\n",
    "fm_predictor.serializer = fm_serializer\n",
    "fm_predictor.deserializer = json_deserializer"
   ]
  },
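  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "For genuinely sparse data, the request can list only the non-zero entries of each row.  The sketch below is a minimal illustration rather than part of this notebook's workflow.  The name `fm_sparse_serializer` is hypothetical, and the `keys` / `shape` / `values` layout is our understanding of the sparse JSON request format described in the Amazon SageMaker documentation for built-in algorithms, so treat it as an assumption and check it against the current documentation before relying on it.\n",
    "\n",
    "```python\n",
    "import json\n",
    "\n",
    "def fm_sparse_serializer(data):\n",
    "    # Hypothetical helper (not part of the SageMaker SDK).\n",
    "    # Encodes each row by the positions and values of its non-zero entries.\n",
    "    instances = []\n",
    "    for row in data:\n",
    "        row = list(row)\n",
    "        keys = [i for i, v in enumerate(row) if v != 0]\n",
    "        values = [float(row[i]) for i in keys]\n",
    "        # Assumed request layout: keys (indices), shape (row length), values\n",
    "        instances.append({'data': {'features': {\n",
    "            'keys': keys,\n",
    "            'shape': [len(row)],\n",
    "            'values': values,\n",
    "        }}})\n",
    "    return json.dumps({'instances': instances})\n",
    "\n",
    "# A toy row with two non-zero entries out of four\n",
    "print(fm_sparse_serializer([[0.0, 0.5, 0.0, 1.0]]))\n",
    "```\n",
    "\n",
    "Only the two non-zero positions are transmitted for the toy row.  For the dense MNIST pixels used in this notebook the savings would be modest, which is why the dense serializer above is kept."
   ]
  },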
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now let's try getting a prediction for a single record."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "result = fm_predictor.predict(train_set[0][30:31])\n",
    "print(result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "OK, a single prediction works.  We see that for one record our endpoint returned some JSON which contains `predictions`, including the `score` and `predicted_label`.  In this case, `score` will be a continuous value between [0, 1] representing the probability we think the digit is a 0 or not.  `predicted_label` will take a value of either `0` or `1` where (somewhat counterintuitively) `1` denotes that we predict the image is a 0, while `0` denotes that we are predicting the image is not of a 0.\n",
    "\n",
    "Let's do a whole batch of images and evaluate our predictive accuracy."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "predictions = []\n",
    "for array in np.array_split(test_set[0], 100):\n",
    "    result = fm_predictor.predict(array)\n",
    "    predictions += [r['predicted_label'] for r in result['predictions']]\n",
    "\n",
    "predictions = np.array(predictions)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "import pandas as pd\n",
    "\n",
    "pd.crosstab(np.where(test_set[1] == 0, 1, 0), predictions, rownames=['actuals'], colnames=['predictions'])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we can see from the confusion matrix above, we predict 951 images of the digit 0 correctly (confusingly this is class 1).  Meanwhile we predict 165 images as the digit 0 when in actuality they aren't, and we miss predicting 29 images of the digit 0 that we should have.\n",
    "\n",
    "*Note: Due to some differences in parameter initialization, your results may differ from those listed above, but should remain reasonably consistent.*"
   ]
  },
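  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The three counts quoted above are enough to compute precision and recall for the positive class (class 1, the digit 0).  The following is a small worked example that assumes the counts from the text (951, 165, and 29); substitute the values from your own confusion matrix, since your run may differ.\n",
    "\n",
    "```python\n",
    "# Counts quoted in the text above; replace with the values from your own run.\n",
    "true_positives = 951   # images of a 0 correctly predicted as class 1\n",
    "false_positives = 165  # other digits incorrectly predicted as class 1\n",
    "false_negatives = 29   # images of a 0 incorrectly predicted as class 0\n",
    "\n",
    "# Precision: of everything predicted as a 0, how much really was a 0\n",
    "precision = true_positives / (true_positives + false_positives)\n",
    "# Recall: of all real 0s, how many the model found\n",
    "recall = true_positives / (true_positives + false_negatives)\n",
    "# F1: harmonic mean of precision and recall\n",
    "f1 = 2 * precision * recall / (precision + recall)\n",
    "\n",
    "print(f'precision = {precision:.3f}')\n",
    "print(f'recall    = {recall:.3f}')\n",
    "print(f'F1        = {f1:.3f}')\n",
    "```\n",
    "\n",
    "With these counts, precision is about 0.852 and recall is about 0.970, so the model finds nearly all of the 0s but also flags a fair number of other digits as 0s."
   ]
  },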
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### (Optional) Delete the Endpoint\n",
    "\n",
    "If you're ready to be done with this notebook, please run the delete_endpoint line in the cell below.  This will remove the hosted endpoint you created and avoid any charges from a stray instance being left on."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "collapsed": true,
    "jupyter": {
     "outputs_hidden": true
    }
   },
   "outputs": [],
   "source": [
    "import sagemaker\n",
    "\n",
    "sagemaker.Session().delete_endpoint(fm_predictor.endpoint)"
   ]
  }
 ],
 "metadata": {
  "celltoolbar": "Tags",
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  },
  "notice": "Copyright 2017 Amazon.com, Inc. or its affiliates. All Rights Reserved.  Licensed under the Apache License, Version 2.0 (the \"License\"). You may not use this file except in compliance with the License. A copy of the License is located at http://aws.amazon.com/apache2.0/ or in the \"license\" file accompanying this file. This file is distributed on an \"AS IS\" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License."
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
